

# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Maximize the two variationally dependent pieces: p(Y | .) and p(M | .)
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

# Jointly fit the mediator model p(M | .) (logistic link) and the outcome
# model p(Y | .) (Gaussian, unit variance) by maximum likelihood, using the
# derivative-free COBYLA optimiser from nloptr.
#
# Arguments:
#   dat        data frame with columns `M` and `Y` plus every variable used by
#              `fmla_m` / `fmla_f`. Rows containing NA are kept
#              (na.action = NULL), presumably so row positions stay aligned
#              with `idx_test`.
#   idx_test   rows (positions or logical mask) whose observed Y is replaced
#              by its own prediction, so they add only a constant to the
#              Y log-likelihood.
#   fmla_f     formula whose design matrix gives the outcome covariates (Xf).
#   fmla_m     formula whose design matrix gives the mediator covariates (Xm).
#   px         passed through unchanged to `estimate_Y()`.
#   beta_start starting values ordered as c(beta_m, beta_f, w0), length
#              ncol(Xm) + ncol(Xf) + 1. If NULL / length 0, all start at 0.1.
#   nlopt_opts options list forwarded to `nloptr()`; the default reproduces
#              the historical settings (COBYLA, xtol_rel 1e-8, 50000 evals).
#
# Returns a list with
#   beta_m  mediator coefficients, named by colnames(Xm);
#   beta_y  outcome coefficients c(beta_f, w0, wa), wa fixed at 0;
#   mle     the maximised joint log-likelihood sum(log p(M|.) + log p(Y|.)).
#
# Raises an error if `beta_start` has the wrong length; warns if nloptr
# reports a failure status (< 0).
optimize_reparam <- function(dat, idx_test, fmla_f, fmla_m, px,
                             beta_start = NULL,
                             nlopt_opts = list("algorithm" = "NLOPT_LN_COBYLA",
                                               "xtol_rel" = 1.0e-8,
                                               "maxeval" = 50000)) {
  # Design matrices; na.action = NULL keeps every row of `dat`.
  frame <- model.frame(dat, na.action = NULL)
  Xm <- as.matrix(model.matrix(fmla_m, data = frame))
  Xf <- as.matrix(model.matrix(fmla_f, data = frame))

  n <- nrow(dat)
  k_m <- ncol(Xm)
  k_f <- ncol(Xf)
  n_par <- k_m + k_f + 1L  # +1 for w0

  # Positions of each parameter group in the optimiser's flat vector.
  idx_m <- seq_len(k_m)
  idx_f <- k_m + seq_len(k_f)
  idx_w0 <- n_par

  if (length(beta_start) == 0) {
    beta_start <- rep(0.1, n_par)
    names(beta_start) <- c(colnames(Xm), colnames(Xf), "w0")
  } else if (length(beta_start) != n_par) {
    stop("`beta_start` must have length ", n_par,
         " (ncol(Xm) + ncol(Xf) + 1), not ", length(beta_start), ".",
         call. = FALSE)
  }

  # Split the flat vector into named beta_m and beta_y = c(beta_f, w0, wa).
  # wa is not optimised here and is pinned at 0.
  split_beta <- function(beta) {
    list(
      beta_m = setNames(beta[idx_m], colnames(Xm)),
      beta_y = setNames(c(beta[idx_f], beta[idx_w0], 0),
                        c(colnames(Xf), "w0", "wa"))
    )
  }

  M <- dat$M
  Y_obs <- dat$Y

  # Joint log-likelihood. Computed on the log scale to avoid log(0) = -Inf
  # from underflowing densities. dnorm(log = TRUE) already includes the
  # -log(sqrt(2 * pi)) constant. The Bernoulli term assumes M is 0/1.
  log_lik <- function(beta) {
    parts <- split_beta(beta)
    Y_hat <- estimate_Y(dat, parts$beta_y, parts$beta_m, px)
    Y <- Y_obs
    Y[idx_test] <- Y_hat[idx_test]
    log_p_Y <- dnorm(Y, mean = Y_hat, sd = 1, log = TRUE)

    eta <- drop(Xm %*% parts$beta_m)
    log_p_M <- M * plogis(eta, log.p = TRUE) +
      (1 - M) * plogis(eta, lower.tail = FALSE, log.p = TRUE)

    sum(log_p_M + log_p_Y)
  }

  # Objective for nloptr: per-observation negative log-likelihood.
  eval_f <- function(beta) -log_lik(beta) / n

  mle_res <- nloptr(x0 = beta_start, eval_f = eval_f, opts = nlopt_opts)
  if (isTRUE(mle_res$status < 0)) {
    warning("nloptr failed (status ", mle_res$status, "): ",
            mle_res$message, call. = FALSE)
  }

  beta_hat <- mle_res$solution
  parts <- split_beta(beta_hat)

  list(beta_m = parts$beta_m,
       beta_y = parts$beta_y,
       mle = log_lik(beta_hat))
}


